#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import random
import xxhash
from datetime import datetime

from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
from peewee import JOIN
from api.db.db_models import DB, File2Document, File
from api.db import StatusEnum, FileType, TaskStatus
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, get_uuid
from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import get_svr_queue_name
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search


def trim_header_by_lines(text: str, max_length: int) -> str:
    """Trim the head of a text to fit within max_length, preserving line structure.

    Args:
        text: Input text to trim.
        max_length: Maximum allowed length of the returned text.

    Returns:
        The original text if it already fits; otherwise the tail of the text
        starting just after the first newline from which the remainder fits.
        If no such newline exists, the text is returned untrimmed.
    """
    len_text = len(text)
    if len_text <= max_length:
        return text
    # Scan forward for a newline from which the remaining text fits the
    # limit, dropping the oldest lines first.
    for i in range(len_text):
        if text[i] == '\n' and len_text - i <= max_length:
            return text[i + 1:]
    # No suitable newline found: return the original text untrimmed.
    return text
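# Illustrative example (not executed): with max_length=12 the oldest lines are
# dropped until the remainder fits:
#   trim_header_by_lines("line1\nline2\nline3", 12)  ->  "line2\nline3"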


class TaskService(CommonService):
    """Service class for managing document processing tasks.
    
    This class extends CommonService to provide specialized functionality for document
    processing task management, including task creation, progress tracking, and chunk
    management. It handles various document types (PDF, Excel, etc.) and manages their
    processing lifecycle.
    
    The class implements a robust task queue system with retry mechanisms and progress
    tracking, supporting both synchronous and asynchronous task execution.
    
    Attributes:
        model: The Task model class for database operations.
    """
    model = Task

    @classmethod
    @DB.connection_context()
    def get_task(cls, task_id):
        """Retrieve detailed task information by task ID.
    
        This method fetches comprehensive task details including associated document,
        knowledge base, and tenant information. It also handles task retry logic and
        progress updates.
    
        Args:
            task_id (str): The unique identifier of the task to retrieve.
    
        Returns:
            dict: Task details dictionary containing all task information and related metadata.
                 Returns None if task is not found or has exceeded retry limit.
        """
        fields = [
            cls.model.id,
            cls.model.doc_id,
            cls.model.from_page,
            cls.model.to_page,
            cls.model.retry_count,
            Document.kb_id,
            Document.parser_id,
            Document.parser_config,
            Document.name,
            Document.type,
            Document.location,
            Document.size,
            Knowledgebase.tenant_id,
            Knowledgebase.language,
            Knowledgebase.embd_id,
            Knowledgebase.pagerank,
            Knowledgebase.parser_config.alias("kb_parser_config"),
            Tenant.img2txt_id,
            Tenant.asr_id,
            Tenant.llm_id,
            cls.model.update_time,
        ]
        docs = (
            cls.model.select(*fields)
                .join(Document, on=(cls.model.doc_id == Document.id))
                .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
                .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
                .where(cls.model.id == task_id)
        )
        docs = list(docs.dicts())
        if not docs:
            return None

        msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task has been received."
        prog = random.random() / 10.0
        if docs[0]["retry_count"] >= 3:
            msg = "\nERROR: Task is abandoned after 3 times attempts."
            prog = -1

        cls.model.update(
            progress_msg=cls.model.progress_msg + msg,
            progress=prog,
            retry_count=docs[0]["retry_count"] + 1,
        ).where(cls.model.id == docs[0]["id"]).execute()

        if docs[0]["retry_count"] >= 3:
            return None

        return docs[0]
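    # Caller contract (illustrative, not enforced here): a None return means
    # the task either does not exist or has exhausted its 3 retries, so
    # workers should drop the corresponding queue message rather than raise.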

    @classmethod
    @DB.connection_context()
    def get_tasks(cls, doc_id: str):
        """Retrieve all tasks associated with a document.
    
        This method fetches all processing tasks for a given document, ordered by page
        number and creation time. It includes task progress and chunk information.
    
        Args:
            doc_id (str): The unique identifier of the document.
    
        Returns:
            list[dict]: List of task dictionaries containing task details.
                       Returns None if no tasks are found.
        """
        fields = [
            cls.model.id,
            cls.model.from_page,
            cls.model.progress,
            cls.model.digest,
            cls.model.chunk_ids,
        ]
        tasks = (
            cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
                .where(cls.model.doc_id == doc_id)
        )
        tasks = list(tasks.dicts())
        if not tasks:
            return None
        return tasks

    @classmethod
    @DB.connection_context()
    def update_chunk_ids(cls, id: str, chunk_ids: str):
        """Update the chunk IDs associated with a task.
    
        This method updates the chunk_ids field of a task, which stores the IDs of
        processed document chunks in a space-separated string format.
    
        Args:
            id (str): The unique identifier of the task.
            chunk_ids (str): Space-separated string of chunk identifiers.
        """
        cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
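    # Example usage (hypothetical variables): chunk IDs travel as a single
    # space-separated string, e.g.
    #   TaskService.update_chunk_ids(task["id"], " ".join(chunk_ids))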

    @classmethod
    @DB.connection_context()
    def get_ongoing_doc_name(cls):
        """Get names of documents that are currently being processed.
    
        This method retrieves information about documents that are in the processing state,
        including their locations and associated IDs. It uses database locking to ensure
        thread safety when accessing the task information.
    
        Returns:
            list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
                        for documents currently being processed. Returns empty list if
                        no documents are being processed.
        """
        with DB.lock("get_task", -1):
            docs = (
                cls.model.select(
                    *[Document.id, Document.kb_id, Document.location, File.parent_id]
                )
                    .join(Document, on=(cls.model.doc_id == Document.id))
                    .join(
                    File2Document,
                    on=(File2Document.document_id == Document.id),
                    join_type=JOIN.LEFT_OUTER,
                )
                    .join(
                    File,
                    on=(File2Document.file_id == File.id),
                    join_type=JOIN.LEFT_OUTER,
                )
                    .where(
                    Document.status == StatusEnum.VALID.value,
                    Document.run == TaskStatus.RUNNING.value,
                    ~(Document.type == FileType.VIRTUAL.value),
                    cls.model.progress < 1,
                    cls.model.create_time >= current_timestamp() - 1000 * 600,
                )
            )
            docs = list(docs.dicts())
            if not docs:
                return []

            return list(
                {
                    (
                        d["parent_id"] if d["parent_id"] else d["kb_id"],
                        d["location"],
                    )
                    for d in docs
                }
            )

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, id):
        """Check if a task should be cancelled based on its document status.
    
        This method determines whether a task should be cancelled by checking the
        associated document's run status and progress. A task should be cancelled
        if its document is marked for cancellation or has negative progress.
    
        Args:
            id (str): The unique identifier of the task to check.
    
        Returns:
            bool: True if the task should be cancelled, False otherwise.
        """
        task = cls.model.get_by_id(id)
        _, doc = DocumentService.get_by_id(task.doc_id)
        return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
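    # Typical caller pattern (a sketch, not part of this module): a worker
    # polls between pipeline stages and aborts early, e.g.
    #   if TaskService.do_cancel(task_id):
    #       return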

    @classmethod
    @DB.connection_context()
    def update_progress(cls, id, info):
        """Update the progress information for a task.
    
        This method updates both the progress message and completion percentage of a task.
        It handles platform-specific behavior (macOS vs others) and uses database locking
        when necessary to ensure thread safety.
    
        Args:
            id (str): The unique identifier of the task to update.
            info (dict): Dictionary containing progress information with keys:
                        - progress_msg (str, optional): Progress message to append
                        - progress (float, optional): Progress percentage (0.0 to 1.0)
        """
        # On macOS, update without a database lock, handling the message and
        # progress fields separately.
        if os.environ.get("MACOS"):
            if info.get("progress_msg"):
                task = cls.model.get_by_id(id)
                # Append the new message, keeping at most the trailing 3000 characters.
                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
            if "progress" in info:
                cls.model.update(progress=info["progress"]).where(
                    cls.model.id == id
                ).execute()
            return

        # On other platforms, update under a database lock for concurrency safety.
        with DB.lock("update_progress", -1):
            if info.get("progress_msg"):
                task = cls.model.get_by_id(id)
                # Append the new message, keeping at most the trailing 3000 characters.
                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
            if "progress" in info:
                cls.model.update(progress=info["progress"]).where(
                    cls.model.id == id
                ).execute()
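    # Example payload (hypothetical values):
    #   TaskService.update_progress(task_id, {
    #       "progress_msg": "Page(1~12): layout analysis finished.",
    #       "progress": 0.25,
    #   })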


# queue_tasks: create parsing tasks for a document according to its type and
# configuration, reuse results from previous runs where possible, and enqueue
# the remaining work.
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
    """Create and queue document processing tasks.
    
    This function creates processing tasks for a document based on its type and configuration.
    It handles different document types (PDF, Excel, etc.) differently and manages task
    chunking and configuration. It also implements task reuse optimization by checking
    for previously completed tasks.
    
    Args:
        doc (dict): Document dictionary containing metadata and configuration.
        bucket (str): Storage bucket name where the document is stored.
        name (str): File name of the document.
        priority (int): Priority level for task queueing.
    
    Note:
        - For PDF documents, tasks are created per page range based on configuration
        - For Excel documents, tasks are created per row range
        - Task digests are calculated for optimization and reuse
        - Previous task chunks may be reused if available
    """
    # Build a fresh task dict with a unique ID that covers the whole document by default.
    def new_task():
        return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}

    # Accumulate the parse tasks to be created for this document.
    parse_task_array = []

    # PDFs are split into per-page-range tasks.
    if doc["type"] == FileType.PDF.value:
        # Fetch the file binary from storage; bucket is the knowledge base ID
        # and name is the file name.
        file_bin = STORAGE_IMPL.get(bucket, name)
        # Layout recognition method; defaults to DeepDOC.
        do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
        # Total number of pages in the PDF.
        pages = PdfParser.total_page_number(doc["name"], file_bin)
        # Pages per task; defaults to 12, or 22 for the "paper" parser.
        page_size = doc["parser_config"].get("task_page_size", 12)
        if doc["parser_id"] == "paper":
            page_size = doc["parser_config"].get("task_page_size", 22)
        # The "one" and "knowledge_graph" parsers, and any non-DeepDOC layout
        # recognition, process the whole document as a single task.
        if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
            page_size = 10 ** 9
        # Page ranges to parse; defaults to the whole document.
        page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
        for s, e in page_ranges:
            # Convert to a 0-based range, clamping against bad user input.
            s -= 1
            s = max(0, s)
            e = min(e - 1, pages)
            # Emit one task per page_size window within the range.
            for p in range(s, e, page_size):
                task = new_task()
                task["from_page"] = p
                task["to_page"] = min(p + page_size, e)
                parse_task_array.append(task)
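        # Worked example (illustrative): a 30-page PDF with the default
        # page_size of 12 and the default range (1, 10**5) becomes three
        # tasks covering pages [0, 12), [12, 24) and [24, 30).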

    # The "table" parser (spreadsheets) is chunked by rows instead of pages.
    elif doc["parser_id"] == "table":
        file_bin = STORAGE_IMPL.get(bucket, name)
        rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
        # One task per 3000 rows.
        for i in range(0, rn, 3000):
            task = new_task()
            task["from_page"] = i
            task["to_page"] = min(i + 3000, rn)
            parse_task_array.append(task)
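        # Worked example (illustrative): a sheet with 7000 rows becomes three
        # tasks covering rows [0, 3000), [3000, 6000) and [6000, 7000).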
    # Every other document type gets a single whole-document task.
    else:
        parse_task_array.append(new_task())

    # Fetch the document's chunking configuration.
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    # Compute a digest for each task so that unchanged tasks can be matched
    # against previous runs and their chunks reused.
    for task in parse_task_array:
        hasher = xxhash.xxh64()
        # Iterate over the config keys in sorted order so the hash is stable
        # across runs: identical inputs must yield identical digests.
        for field in sorted(chunking_config.keys()):
            if field == "parser_config":
                # Drop fields that should not affect the task fingerprint.
                for k in ["raptor", "graphrag"]:
                    if k in chunking_config[field]:
                        del chunking_config[field][k]
            hasher.update(str(chunking_config[field]).encode("utf-8"))
        # Mix the task's identity (document and page range) into the hash.
        for field in ["doc_id", "from_page", "to_page"]:
            hasher.update(str(task.get(field, "")).encode("utf-8"))
        task["digest"] = hasher.hexdigest()
        task["progress"] = 0.0
        task["priority"] = priority

    # Fetch the document's previous tasks and try to reuse their chunks.
    prev_tasks = TaskService.get_tasks(doc["id"])
    ck_num = 0
    if prev_tasks:
        # Reuse chunks from previously completed, matching tasks.
        for task in parse_task_array:
            ck_num += reuse_prev_task_chunks(task, prev_tasks, chunking_config)
        # Remove the old task records for this document.
        TaskService.filter_delete([Task.doc_id == doc["id"]])
        chunk_ids = []
        for task in prev_tasks:
            if task["chunk_ids"]:
                chunk_ids.extend(task["chunk_ids"].split())
        if chunk_ids:
            # Delete the stale chunks from the document store (Elasticsearch by default).
            settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]),
                                         chunking_config["kb_id"])
    # Record how many chunks were carried over.
    DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

    # Persist the new tasks in bulk and mark the document as parsing.
    bulk_insert_into_db(Task, parse_task_array, True)
    DocumentService.begin2parse(doc["id"])

    # Queue only the unfinished tasks; reused tasks already sit at progress 1.0.
    unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
    for unfinished_task in unfinished_task_array:
        assert REDIS_CONN.queue_product(
            get_svr_queue_name(priority), message=unfinished_task
        ), "Can't access Redis. Please check the Redis status."


# reuse_prev_task_chunks: carry over the chunks of a matching completed task
# so the new task can skip re-parsing.
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
    """Attempt to reuse chunks from previous tasks for optimization.
    
    This function checks if chunks from previously completed tasks can be reused for
    the current task, which can significantly improve processing efficiency. It matches
    tasks based on page ranges and configuration digests.
    
    Args:
        task (dict): Current task dictionary to potentially reuse chunks for.
        prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
        chunking_config (dict): Configuration dictionary for chunk processing.
    
    Returns:
        int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
    
    Note:
        Chunks can only be reused if:
        - A previous task exists with matching page range and configuration digest
        - The previous task was completed successfully (progress = 1.0)
        - The previous task has valid chunk IDs
    """
    # Scan the previous tasks for one matching this task's page start and digest.
    idx = 0
    while idx < len(prev_tasks):
        prev_task = prev_tasks[idx]
        if prev_task.get("from_page", 0) == task.get("from_page", 0) \
                and prev_task.get("digest", "") == task.get("digest", ""):
            break
        idx += 1

    # No matching previous task: nothing to reuse.
    if idx >= len(prev_tasks):
        return 0
    prev_task = prev_tasks[idx]
    # The match must have finished successfully and produced chunks.
    if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
        return 0
    # Carry the chunks over and mark this task as already complete.
    task["chunk_ids"] = prev_task["chunk_ids"]
    task["progress"] = 1.0
    # Prefix the progress message with the page range for wide-span tasks.
    if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
        task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
    else:
        task["progress_msg"] = ""
    # Add a timestamp and a reuse notice to the progress message.
    task["progress_msg"] = " ".join(
        [datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
    # Clear the previous task's chunk_ids so they cannot be reused twice.
    prev_task["chunk_ids"] = ""

    return len(task["chunk_ids"].split())
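# End-to-end sketch (hypothetical values): if a previous task with from_page=0
# and an identical digest finished with chunk_ids "c1 c2 c3", the new task
# inherits those three chunks (return value 3), is marked progress=1.0, and is
# therefore skipped by the Redis queueing loop in queue_tasks.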
